import numpy as np
from VMBPO import VMBPO
from stochastic_env import stochastic_env
from utils import compute_initial_state_value
from stochastic_cliff_env import stochastic_cliff
import matplotlib.pyplot as plt
import time

# Build the environment. The class imported at the top of this file is
# `stochastic_cliff` (from the `stochastic_cliff_env` module); the module name
# itself is never bound, so calling `stochastic_cliff_env()` raised NameError.
env = stochastic_cliff()
# generate_env() returns two objects; below they are indexed as [s, a] and
# used as the transition model and the reward table respectively.
p_s_given_sa, r_sa = env.generate_env()

# Number of training episodes per run and the step cap within one episode.
episodes = 40000
iterations = 200

# Initial step size for the adaptive-beta update (reset at the start of each run).
learning_rate = 0.1

# Algorithm variant to run. Supported values checked by the training loop:
#   "VMBPO"       - reward r/b - KL, fixed beta
#   "beta_VMBPO"  - same reward, beta adapted online
#   "MnM"         - reward log(r/b) - KL
algo = "MnM"

# Temperature dividing the reward; also the starting value of the adaptive beta.
beta = 1
print('Beta:{0}'.format(beta))
# One row per evaluation point (every 20th episode), one column per run.
average_return = np.zeros((int(episodes/20),5))

# Run one independent training repetition per column of `average_return`.
for i in range(average_return.shape[1]):
    learning_rate = 0.1
    b = beta
    print(i)
    start_time = time.time()
    algorithm = VMBPO(env.num_states, env.num_actions, env)
    return_per_episode = []
    adaptive_beta = []
    for e in range(episodes):
        # Halve the beta step size every 5000 episodes.
        if e % 5000 == 4999:
            learning_rate = learning_rate / 2
        if e % 1000 == 999:
            print(e)
        # Evaluate the current Q table every 20th episode.
        if e % 20 == 0:
            return_per_episode.append(
                compute_initial_state_value(
                    env.num_states, env.num_actions, algorithm.Q, p_s_given_sa, r_sa))
        s = 0

        # Exponentially tilt the model by the greedy state values and
        # renormalise over next states. Subtracting max(V) before exp() leaves
        # the normalised distribution unchanged but prevents overflow.
        V = np.max(algorithm.Q, axis=-1)
        q_s_given_sa = p_s_given_sa * np.exp(V - np.max(V))
        q_s_given_sa = q_s_given_sa / q_s_given_sa.sum(-1)[:, :, None]

        # KL(q || p) per (s, a). Zero-probability entries give 0 * log(0) = NaN,
        # which nansum drops; silence the corresponding floating-point warnings
        # instead of wrapping the expression in a catch-all handler.
        with np.errstate(divide='ignore', invalid='ignore'):
            kl = np.nansum(
                q_s_given_sa * (np.log(q_s_given_sa) - np.log(p_s_given_sa)),
                axis=-1)

        for _ in range(iterations):
            a = algorithm.epsilon_greedy(s)
            try:
                next_s = np.random.choice(env.num_states, p=q_s_given_sa[s, a])
            except ValueError:
                # Invalid probability vector (NaN / not summing to 1): dump the
                # models for debugging and stop rather than continue with an
                # undefined or stale next state.
                print(q_s_given_sa)
                print(p_s_given_sa)
                raise
            if algo == "MnM":
                algorithm.update_step(s, a, next_s, np.log(r_sa[s, a]/b) - kl[s, a])
            else:
                algorithm.update_step(s, a, next_s, r_sa[s, a]/b - kl[s, a])
            # State 9 ends the episode.
            if s == 9:
                break
            s = next_s
            if algo == "beta_VMBPO":
                if e != 0:
                    # NOTE(review): `s` has already been advanced to next_s here,
                    # so this reads kl[next_s, a] - confirm that is intended.
                    b -= learning_rate * (0.01 - kl[s, a])
                    # Reset beta when the update drives it negative.
                    if b < 0:
                        b = 0.1
            adaptive_beta.append(np.log10(b))
    average_return[:, i] += np.array(return_per_episode)
    print(time.time()-start_time)



# plt.imshow(np.max(algorithm.Q, axis=1).reshape(-1, 10))
# plt.show()
# Episode index of each evaluation point (kept for plotting).
episode_axis = np.arange(1, episodes, 20)
# Aggregate the return curves across runs (columns of average_return).
mean = np.average(average_return, axis=1)
std = np.std(average_return, axis=1)

# Name the output files after the algorithm variant string. Formatting the
# `algorithm` object instead would embed its repr (including a memory address)
# in the file name. np.save appends the ".npy" extension.
np.save('{0}_mean'.format(algo), mean)
np.save('{0}_std'.format(algo), std)
#
# np.save('n_{0}_mean'.format(beta),mean)
# np.save('n_{0}_std'.format(beta), std)
# np.save('n_{0}_beta'.format(beta), adaptive_beta)
